import logging

def train(model, x, data, sens, optimizer, critereon, device="cuda", use_sens_as_labels=False):
    """Switch ``model`` to training mode and run one full-batch training step.

    This is a thin entry point; all of the work is delegated to
    ``train_full_batch``.

    Args:
        model: Module to train; ``model.train()`` is called on it first.
        x: Node input passed through unchanged to ``train_full_batch``.
        data: Graph/batch object passed through unchanged.
        sens: Sensitive-attribute tensor passed through unchanged.
        optimizer: Optimizer passed through unchanged.
        critereon: Loss function passed through unchanged.
        device: Accepted for interface compatibility; not used by this function.
        use_sens_as_labels: Forwarded flag selecting ``sens`` as the target.

    Returns:
        Whatever ``train_full_batch`` returns for this step.
    """
    model.train()
    step_loss = train_full_batch(
        model,
        x,
        data,
        sens,
        optimizer,
        critereon,
        use_sens_as_labels,
    )
    return step_loss

def train_full_batch(model, x, data, sens, optimizer, critereon, use_sens_as_labels):
    """Run a single full-batch optimisation step on the training rows.

    The model is evaluated on the whole input (``x`` together with
    ``data.edge_index`` and ``data.edge_attr``). The loss is computed only on
    the rows selected by ``data.train_mask``, back-propagated, and applied with
    one ``optimizer.step()``.

    Args:
        model: Callable invoked as ``model(x, edge_index, edge_attr)``; it is
            switched to training mode before the forward pass.
        x: Input passed straight to ``model``.
        data: Object exposing ``edge_index``, ``edge_attr`` and ``train_mask``,
            plus ``y`` when ``use_sens_as_labels`` is false.
        sens: Tensor indexed with ``data.train_mask`` and used as the target
            when ``use_sens_as_labels`` is true (presumably a per-node
            sensitive attribute; confirm against callers).
        optimizer: Optimizer whose gradients are zeroed and which is stepped once.
        critereon: Loss function called as ``critereon(y_pred, y_true)``.
        use_sens_as_labels: If true, train against ``sens`` instead of ``data.y``.

    Returns:
        The value returned by ``critereon`` for this step, after ``backward()``
        has been called on it.
    """
    model.train()
    optimizer.zero_grad()

    mask = data.train_mask
    # The model receives edge_attr as a single column: shape (num_edges, 1).
    y_pred = model(x, data.edge_index, data.edge_attr.reshape(-1, 1))[mask]

    labels = sens if use_sens_as_labels else data.y
    y_true = labels[mask]
    # Drop only a trailing singleton column, e.g. (n, 1) -> (n,). A bare
    # ``squeeze()`` would also collapse the row dimension when exactly one row
    # is selected, turning (1, 1) or (1,) into a 0-d tensor that no longer
    # lines up with y_pred's leading dimension of 1.
    if y_true.dim() > 1:
        y_true = y_true.squeeze(-1)

    loss = critereon(y_pred, y_true)
    loss.backward()
    optimizer.step()

    return loss
